import pandas as pd
import numpy as np
from itertools import combinations
from tqdm import tqdm
from typing import List, Dict, Any, Optional, Tuple
from sklearn.preprocessing import LabelEncoder
from hypersense.importance.base_analyzer import BaseImportanceAnalyzer
from hashlib import md5


class FANOVAAnalyzer(BaseImportanceAnalyzer):
    """Functional-ANOVA (fANOVA) hyperparameter importance analyzer.

    The observed metric is decomposed into additive components ``f_u``, one per
    subset ``u`` of hyperparameters::

        f_{} = mean(metric)
        f_u  = a_u - sum(f_w for every proper subset w of u, including {})

    where ``a_u`` is the mean metric for each observed value combination of the
    parameters in ``u``. The raw importance of ``u`` is the mean of ``f_u ** 2``
    over its observed value combinations (each combination weighted equally);
    raw importances of all evaluated subsets are then normalised by their sum.
    """

    def __init__(self, max_subset_size: Optional[int] = None):
        """
        Functional ANOVA based hyperparameter importance analyzer.

        Args:
            max_subset_size: Largest size of parameter combinations to evaluate.
                If None, defaults to 2, i.e. only single parameters and
                two-parameter combinations are evaluated. Larger values add
                higher-order terms to the normalisation, but only single-parameter
                and pairwise importances are reported.
        """
        super().__init__()
        self.max_subset_size = max_subset_size
        self.metric_name: Optional[str] = None
        # Filled by fit(); initialised here so explain_interactions() works
        # before fit() instead of raising AttributeError.
        self.interaction_importances_: Optional[Dict[Tuple[str, str], float]] = None
        self.label_encoders: Dict[str, LabelEncoder] = {}

    def fit(self, configs: List[Dict[str, Any]], scores: List[float]) -> None:
        """
        Fit the fANOVA analyzer on given hyperparameter configurations and scores.

        Non-numeric hyperparameter columns are label-encoded (the fitted encoders
        are stored in ``self.label_encoders``) so every parameter can be grouped
        and matched by value.

        Args:
            configs: One dict of hyperparameter values per trial.
            scores: Metric value of each trial, aligned with ``configs``.

        Raises:
            ValueError: If ``configs`` and ``scores`` differ in length, or a
                hyperparameter uses a name reserved internally
                (``"id"`` or ``"_metric_"``).
        """
        if len(configs) != len(scores):
            raise ValueError(
                "configs and scores must have the same length "
                f"(got {len(configs)} and {len(scores)})."
            )

        df = pd.DataFrame(configs)
        # "id" and "_metric_" are used as internal column names; a hyperparameter
        # with either name would collide with them.
        reserved = {"id", "_metric_"} & set(df.columns)
        if reserved:
            raise ValueError(
                f"Hyperparameter names {sorted(reserved)} are reserved by FANOVAAnalyzer."
            )

        self.label_encoders = {}
        for col in df.columns:
            # Encode every non-numeric column (object, string, category, ...),
            # not only dtype "object".
            if not pd.api.types.is_numeric_dtype(df[col]):
                le = LabelEncoder()
                df[col] = le.fit_transform(df[col])
                self.label_encoders[col] = le
        df["_metric_"] = scores
        self.metric_name = "_metric_"
        self.feature_importances_, self.interaction_importances_ = self._analyze(df)

    def explain(self) -> Dict[str, float]:
        """Return normalised single-parameter importances.

        Raises:
            ValueError: If fit() has not been called.
        """
        importances = getattr(self, "feature_importances_", None)
        if importances is None:
            raise ValueError("Importance scores not available. Call fit() first.")
        return importances

    def explain_interactions(self) -> Dict[Tuple[str, str], float]:
        """Return normalised pairwise importances, or ``{}`` before fit()."""
        interactions = getattr(self, "interaction_importances_", None)
        if interactions is None:
            return {}
        return interactions

    def rank(self) -> List[str]:
        """Return parameter names sorted by importance, most important first.

        Raises:
            ValueError: If fit() has not been called.
        """
        importances = getattr(self, "feature_importances_", None)
        if importances is None:
            raise ValueError("Importance scores not available. Call fit() first.")
        return sorted(importances, key=importances.get, reverse=True)

    def _analyze(
        self, cal_df: pd.DataFrame
    ) -> Tuple[Dict[str, float], Dict[Tuple[str, str], float]]:
        """Run the fANOVA decomposition on an encoded frame.

        Args:
            cal_df: One row per trial; every column except ``self.metric_name``
                is treated as a hyperparameter.

        Returns:
            ``(feature_importances, interaction_importances)``: importances of
            single parameters keyed by name and of pairs keyed by the parameter
            tuple (in column order), normalised by the sum of all evaluated
            components (0.0 for all if that sum is not positive).
        """
        metric = self.metric_name
        params = [col for col in cal_df.columns if col != metric]

        # f_table accumulates every fitted component. Row 0 is the empty-set
        # component f_{} (the grand mean). Parameter columns that are not part of
        # a component's subset stay NaN.
        f_table = pd.DataFrame(columns=["id"] + params + [metric])
        mean_score = cal_df[metric].mean()
        f_table.loc[0] = [self._stable_hash([])] + [np.nan] * len(params) + [mean_score]

        max_subset = self.max_subset_size if self.max_subset_size is not None else 2
        # Subsets are visited in increasing size, so every proper subset's
        # component is already in f_table when it is needed.
        subsets = [u for k in range(1, max_subset + 1) for u in combinations(params, k)]

        v_table = []
        with tqdm(subsets, desc="Fitting fANOVA") as progress:
            for u in progress:
                # Ordered list (not a set) so a_u's column order is deterministic.
                ne_u = [p for p in params if p not in u]
                a_u = self._get_a_u(cal_df, u, ne_u)
                sum_f_w = self._get_sum_f_w(f_table, u)
                f_u = self._get_f_u(a_u, sum_f_w, u)

                f_table = pd.concat([f_table, f_u], ignore_index=True)

                # Mean of f_u^2 over observed value combinations of u.
                importance = (f_u[metric].to_numpy() ** 2).mean()
                v_table.append((u, importance))

        # Normalize importances
        total_weight = sum(raw_weight for _, raw_weight in v_table)
        feature_importances: Dict[str, float] = {}
        interaction_importances: Dict[Tuple[str, str], float] = {}
        for name_tuple, raw_weight in v_table:
            norm_weight = raw_weight / total_weight if total_weight > 0 else 0.0
            if len(name_tuple) == 1:
                feature_importances[name_tuple[0]] = norm_weight
            elif len(name_tuple) == 2:
                interaction_importances[name_tuple] = norm_weight

        return feature_importances, interaction_importances

    def _get_a_u(
        self, cal_df: pd.DataFrame, u: Tuple[str, ...], ne_u: List[str]
    ) -> pd.DataFrame:
        """Mean metric ``a_u`` for every observed value combination of ``u``.

        Returns a frame with columns ``["id"] + list(u) + list(ne_u) + [metric]``.
        ``id`` and the ``ne_u`` columns are NaN placeholders so the frame lines
        up with ``f_table``.
        """
        metric = self.metric_name
        # Only the metric is needed; groupby drops rows with NaN keys (pandas default).
        a_u = cal_df.groupby(list(u))[metric].mean().reset_index()
        # reindex adds the ne_u columns filled with NaN.
        a_u = a_u.reindex(columns=list(u) + list(ne_u) + [metric])
        a_u.insert(0, "id", np.nan)
        return a_u

    def _get_sum_f_w(self, f_table: pd.DataFrame, u: Tuple[str, ...]) -> pd.DataFrame:
        """Collect the fitted components ``f_w`` for every proper subset ``w`` of ``u``.

        The empty subset is included so the grand mean ``f_{}`` is subtracted too.
        Each returned row has values only in ``id``, the metric and the columns
        of its own subset; other parameter columns are NaN.
        """
        metric = self.metric_name
        frames = []
        for subset in [()] + self._all_nonempty_subsets(u):
            rows = f_table[f_table["id"] == self._stable_hash(subset)]
            # Select the subset's columns explicitly so they survive even if the
            # metric happens to be NaN.
            frames.append(rows[["id", *subset, metric]])
        return pd.concat(frames, ignore_index=True)

    def _get_f_u(
        self, a_u: pd.DataFrame, sum_f_w: pd.DataFrame, u: Tuple[str, ...]
    ) -> pd.DataFrame:
        """Turn ``a_u`` into ``f_u`` in place by subtracting each lower-order ``f_w``.

        A row of ``sum_f_w`` with no parameter values is ``f_{}`` and applies to
        every cell. Any other row is subtracted from the cells of ``a_u`` whose
        values match it on that row's parameter columns.
        """
        metric = self.metric_name

        for _, row in sum_f_w.iterrows():
            subset_cols = [
                c for c in row.index if c not in ("id", metric) and pd.notnull(row[c])
            ]

            if not subset_cols:
                a_u[metric] -= row[metric]
                continue

            # Keys in f_w and a_u are copies of the same observed values, so exact
            # equality is correct. np.isclose's default atol (1e-8) would merge
            # distinct small-magnitude hyperparameter values.
            match_idx = np.logical_and.reduce(
                [(a_u[col] == row[col]).to_numpy() for col in subset_cols]
            )
            a_u.loc[match_idx, metric] -= row[metric]

        a_u["id"] = self._stable_hash(u)
        return a_u

    def _all_nonempty_subsets(self, s):
        """Return every non-empty *proper* subset of ``s`` (``s`` itself excluded)."""
        return [combo for size in range(1, len(s)) for combo in combinations(s, size)]

    def _stable_hash(self, obj) -> str:
        """Order-insensitive md5 id for a collection of parameter names."""
        return md5(str(sorted(obj)).encode("utf-8")).hexdigest()
